#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# vim: tabstop=4 shiftwidth=4 softtabstop=4
#
# Copyright (C) 2018-2022 GEM Foundation
#
# OpenQuake is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# OpenQuake is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with OpenQuake.  If not, see <http://www.gnu.org/licenses/>.

import operator
from openquake.baselib import sap, parallel, general
from openquake.hazardlib import contexts, calc
from openquake.commonlib import datastore

# Number of bytes in one megabyte (2**20); context sizes in bytes are
# divided by this value to express them in MB
TWO20 = 2**20
# Threshold, in MB, on the cumulative size of the contexts collected by a
# task: once exceeded, the task yields a partial result and starts over
MAX_MB = 500

# NOTE: the triple-quoted string below is a bare expression statement, so
# none of the code inside it is ever executed. It keeps around a variant of
# the two functions defined later in this file in which the tasks receive a
# pmap backed by a shared (N, L, G) array and call cmaker.update on it.
'''Using shared memory

import logging
from openquake.hazardlib import probability_map

def classical(srcs, srcfilter, cmaker, pmap, monitor):
    ctxs_mb = 0
    allctxs = []
    with pmap.shared as array:
        slc = slice(cmaker.start, cmaker.start + len(cmaker.gsims))
        pmap.array = array[:, :, slc]
        for src in srcs:
            sites = srcfilter.get_close_sites(src)
            if sites is None:
                continue
            for ctx in cmaker.get_ctx_iter(src, sites):
                allctxs.append(ctx)
                ctxs_mb += ctx.nbytes / TWO20  # TWO20=1MB
                if ctxs_mb > MAX_MB:
                    yield cmaker.update(pmap, contexts.concat(allctxs))
                    allctxs.clear()
                    ctxs_mb = 0
        if ctxs_mb:
            yield cmaker.update(pmap, contexts.concat(allctxs))


def classical_indep(calc_id: int):
    """
    Classical calculator as postprocessor; works only for independent
    sources and ruptures.
    """
    parent = None if calc_id is None else datastore.read(calc_id)
    log, dstore = datastore.build_log_dstore(parent=parent)
    with dstore, log:
        oq = dstore.parent['oqparam']
        sitecol = dstore.parent['sitecol']
        srcfilter = calc.filters.SourceFilter(sitecol, oq.maximum_distance)
        cmakers = contexts.read_cmakers(dstore.parent).to_array()
        N, L, G = len(sitecol), oq.imtls.size, sum(
            len(cm.gsims) for cm in cmakers)
        logging.info(f'{N=}, {L=}, {G=}, {general.humansize(N*L*G*8)=}')
        pmap = probability_map.ProbabilityMap(sitecol.sids, L, G)
        dstore.swmr_on()
        smap = parallel.Starmap(classical, h5=dstore)
        pmap.shared = smap.create_shared((N, L, G), value=1.)
        csm = dstore['_csm']
        maxw = csm.get_max_weight(oq)
        for grp_id, sg in enumerate(csm.src_groups):
            logging.info('Sending group #%d', grp_id)
            for block in general.block_splitter(
                    sg, maxw, operator.attrgetter('weight'), sort=True):
                smap.submit((block, srcfilter, cmakers[grp_id], pmap))
        smap.reduce()
classical_indep.calc_id = 'parent calculation'
'''

def classical(srcs, srcfilter, cmaker, monitor):
    """
    Task function: collect the contexts generated by the given sources
    and yield a partial result each time their cumulative size goes
    above MAX_MB megabytes, plus a final one for the leftover contexts.

    :param srcs: iterable of sources to process
    :param srcfilter: object whose get_close_sites(src) returns the sites
        to use for a source, or None when the source must be skipped
    :param cmaker: object providing get_ctx_iter(src, sites) and
        get_pmap(ctxs)
    :param monitor: not used in the body of this function
    :yields: dictionaries with keys 'ctxs_mb' (size in MB of the batch of
        contexts) and 'pmap' (what cmaker.get_pmap returns for the
        concatenated batch)
    """
    def partial_result(mbytes, ctxs):
        # the size is stored first, then the pmap is computed
        return {'ctxs_mb': mbytes,
                'pmap': cmaker.get_pmap(contexts.concat(ctxs))}

    batch = []  # contexts collected since the last yield
    batch_mb = 0  # cumulative size of the batch, in MB
    for src in srcs:
        close_sites = srcfilter.get_close_sites(src)
        if close_sites is not None:
            for ctx in cmaker.get_ctx_iter(src, close_sites):
                batch.append(ctx)
                batch_mb += ctx.nbytes / TWO20  # bytes -> MB
                if batch_mb > MAX_MB:
                    # the batch went over the threshold: emit it and
                    # start a new one
                    yield partial_result(batch_mb, batch)
                    batch.clear()
                    batch_mb = 0
    if batch_mb:
        # contexts left over below the threshold
        yield partial_result(batch_mb, batch)


def classical_indep(calc_id: int):
    """
    Classical calculator as postprocessor; works only for independent
    sources and ruptures.

    Reads the parameters, the site collection and the context makers from
    the parent datastore, submits one `classical` task for each block of
    sources of each source group and finally prints the total size in MB
    of the contexts reported by the tasks.

    :param calc_id: ID of the parent calculation
    """
    parent_dstore = datastore.read(calc_id) if calc_id is not None else None
    log, dstore = datastore.build_log_dstore(parent=parent_dstore)
    with dstore, log:
        oqparam = dstore.parent['oqparam']
        sites = dstore.parent['sitecol']
        src_filter = calc.filters.SourceFilter(
            sites, oqparam.maximum_distance)
        cmaker_by_grp = contexts.read_cmakers(dstore.parent).to_array()
        dstore.swmr_on()
        starmap = parallel.Starmap(classical, h5=dstore)
        compsm = dstore['_csm']
        max_weight = compsm.get_max_weight(oqparam)
        by_weight = operator.attrgetter('weight')
        for grp_id, src_group in enumerate(compsm.src_groups):
            blocks = general.block_splitter(
                src_group, max_weight, by_weight, sort=True)
            for block in blocks:
                # each task gets the context maker of its source group
                starmap.submit((block, src_filter, cmaker_by_grp[grp_id]))
        # iterating on the Starmap gives the dictionaries yielded by the
        # tasks; the local name must stay `ctxs_mb` since it is printed
        ctxs_mb = sum(res['ctxs_mb'] for res in starmap)
        print(f'{ctxs_mb=}')

# Description of the calc_id parameter, stored as a function attribute
# (presumably read by sap to build the command-line help -- to be confirmed)
classical_indep.calc_id = 'parent calculation'

# When executed as a script, hand the entry point over to sap.run
if __name__ == '__main__':
    sap.run(classical_indep)
